Projekt z Podstaw Sztucznej Inteligencji

W projekcie wykorzystujemy informacje zebrane podczas gry w Affective SpaceShooter 2.

W pierwszym etapie analizujemy dane pod kątem zależności między wynikiem gry a poszczególnymi cechami osobowości.

W drugiej części dokonujemy uczenia nadzorowanego - na podstawie zbioru cech osobowości staramy się przewidzieć średni wynik punktowy gry reprezentowany przez klasy "low" i "medium".

Preprocessing

In [1]:
import pandas as pd
import numpy as np
import csv


md = pd.read_csv("BIRAFFE-metadata.csv", sep=';')
#usunięcie tych rekordów gdzie osoba nie ma danych z gry w space
md = md[pd.notnull(md['SPACE'])]
md = md[pd.notnull(md['OPENNESS'])]
#zostawienie id tych osób, bo pliki mają w nazwie id
ids = md['ID'].values

import csv
import json    
print("start")
pd.set_option('display.max_columns', 500)
#import metadata
data = pd.read_csv("merged_scores.csv", sep=',')
data = data[pd.notnull(data['OPENNESS'])]
data = data[pd.notnull(data['CONSCIENTIOUSNESS'])]
data = data[pd.notnull(data['EXTRAVERSION'])]
data = data[pd.notnull(data['AGREEABLENESS'])]
data = data[pd.notnull(data['NEUROTICISM'])]

data=data[data.Score == 'GameOver']
#data.head(25)
type(data)
mean_c=[]

with open('mean_scores.csv', 'w', newline='') as csvfile:
    #nazwy kolumn- wszystkie z plików json
    fieldnames = ["P_ID","OPENNESS","CONSCIENTIOUSNESS","EXTRAVERSION","AGREEABLENESS","NEUROTICISM","Mean"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    #ustawienie nagłówków
    writer.writeheader()
    for my_id in ids:
        data1=data.loc[data['P_ID'] == my_id]
        nr=data1.shape[0]
        #print(data1)
        score=data1['Value'].sum()
        if (score):
            mean = score/nr
            #print(mean)
            new_data={ 'P_ID': data1['P_ID'].iloc[0], 'OPENNESS': data1['OPENNESS'].iloc[0], 'CONSCIENTIOUSNESS': data1['CONSCIENTIOUSNESS'].iloc[0],'EXTRAVERSION':  data1['EXTRAVERSION'].iloc[0],'AGREEABLENESS': data1['AGREEABLENESS'].iloc[0],'NEUROTICISM': data1['NEUROTICISM'].iloc[0],'Mean': mean}
            writer.writerow(new_data)
start

Obliczamy średni wynik gry dla każdej osoby, który będziemy zestawiać z cechami osobowości.

Analiza danych

In [2]:
import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sns
import pandas as pd
import numpy as np
import csv

pd.set_option('display.max_columns', 500)
#import metadata
data = pd.read_csv("mean_scores.csv", sep=',')
data.head()
fig = px.scatter(data, x = 'OPENNESS', y = 'Mean', title='test')
fig.show()
fig1 = px.scatter(data, x = 'CONSCIENTIOUSNESS', y = 'Mean', title='test')
fig1.show()
fig2 = px.scatter(data, x = 'EXTRAVERSION', y = 'Mean', title='test')
fig2.show()
fig3 = px.scatter(data, x = 'AGREEABLENESS', y = 'Mean', title='test')
fig3.show()
fig4 = px.scatter(data, x = 'NEUROTICISM', y = 'Mean', title='test')
fig4.show()

Powyższe wykresy przedstawiają korelację średniego wyniku gry z każdą z cech osobowości. Można z nich wywnioskować, że zależności nie istnieją.

In [3]:
import matplotlib.pyplot as plt
import seaborn as sns
fig, ax = plt.subplots(figsize=(5,5))
sns.heatmap(data.corr(), vmax=1.0, center=0, fmt='.2f', linewidths=.9, annot=True,cbar_kws={"shrink": .70})
plt.show();

Powyższe zestawienie również pokazuje, że korelacja między badanymi elementami jest niska.

In [4]:
data.boxplot(column=['OPENNESS', 'CONSCIENTIOUSNESS', 'EXTRAVERSION', 'AGREEABLENESS', 'NEUROTICISM'], rot=45)
Out[4]:
<matplotlib.axes._subplots.AxesSubplot at 0x1a3938fd550>
In [5]:
data.boxplot(column=['Mean'], rot=45)
print("Większość wyników znajduje się pomiędzy około 400 a 1100 punktów, grupa o średniej <=2000 to najbardziej wiarygodna grupa testowa")
Większość wyników znajduje się pomiędzy około 400 a 1100 punktów, grupa o średniej <=2000 to najbardziej wiarygodna grupa testowa
In [6]:
data_trimmed=data.loc[data['Mean'] <= 2000]

fig, ax = plt.subplots(figsize=(5,5))
sns.heatmap(data_trimmed.corr(), vmax=1.0, center=0, fmt='.2f', linewidths=.9, annot=True,cbar_kws={"shrink": .70})
plt.show();

Zestawienie korelacji dla danych testowych, w których średni wynik gry wyniósł <= 2000 punktów. Korelacja między badanymi elementami nadal jest niska.

In [7]:
pd.set_option('display.max_columns', 500)
#import metadata
data = pd.read_csv("mean_scores.csv", sep=',')
data.head()
fig = px.scatter(data_trimmed, x = 'OPENNESS', y = 'Mean', title='test')
fig.show()
fig1 = px.scatter(data_trimmed, x = 'CONSCIENTIOUSNESS', y = 'Mean', title='test')
fig1.show()
fig2 = px.scatter(data_trimmed, x = 'EXTRAVERSION', y = 'Mean', title='test')
fig2.show()
fig3 = px.scatter(data_trimmed, x = 'AGREEABLENESS', y = 'Mean', title='test')
fig3.show()
fig4 = px.scatter(data_trimmed, x = 'NEUROTICISM', y = 'Mean', title='test')
fig4.show()

Niską korelację potwierdzają również wykresy.

In [8]:
from sklearn.model_selection import train_test_split 
from sklearn import metrics
from sklearn import linear_model as ln


def showBarPlot(X, Y, title):
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)

    regressor = ln.LinearRegression()  
    regressor.fit(X_train, y_train) 
    print(regressor.score(X_train, y_train))


    print(regressor.intercept_)
    print(regressor.coef_)


    y_pred = regressor.predict(X_test)

    df1 = pd.DataFrame({'Actual': y_test, 'Predicted': y_pred})
    df2 = df1.head(25)


    df2.plot(kind='bar', figsize=(10,5))
    plt.grid(which='major', linestyle='-', linewidth='0.5', color='green')
    plt.grid(which='minor', linestyle=':', linewidth='0.5', color='black')
    plt.show()
    
    print(title)
    print('Mean Absolute Error:', metrics.mean_absolute_error(y_test, y_pred))  
    print('Mean Squared Error:', metrics.mean_squared_error(y_test, y_pred))  
    print('Root Mean Squared Error:', np.sqrt(metrics.mean_squared_error(y_test, y_pred)))
    print('\n\n')

    
    
X = data[['OPENNESS']].values
Y = data['Mean'].values 
showBarPlot(X,Y,'OPENNESS - Mean')

X = data[['CONSCIENTIOUSNESS']].values
Y = data['Mean'].values 
showBarPlot(X,Y,'CONSCIENTIOUSNESS - Mean')

X = data[['EXTRAVERSION']].values
Y = data['Mean'].values 
showBarPlot(X,Y,'EXTRAVERSION - Mean')

X = data[['AGREEABLENESS']].values
Y = data['Mean'].values 
showBarPlot(X,Y,'AGREEABLENESS - Mean')

X = data[['NEUROTICISM']].values
Y = data['Mean'].values 
showBarPlot(X,Y,'NEUROTICISM - Mean')
0.0036420652538303733
765.8645885671808
[26.00805618]
OPENNESS - Mean
Mean Absolute Error: 583.736574556191
Mean Squared Error: 618070.2400471191
Root Mean Squared Error: 786.1744336005331



0.0013540911675161693
724.3240302389381
[11.47845344]
CONSCIENTIOUSNESS - Mean
Mean Absolute Error: 723.3022509129872
Mean Squared Error: 1303843.7113480025
Root Mean Squared Error: 1141.859759930265



0.002271190205002993
744.1514840736597
[16.20205046]
EXTRAVERSION - Mean
Mean Absolute Error: 641.4666495671707
Mean Squared Error: 960820.1478761613
Root Mean Squared Error: 980.214337722195



0.024552433192815526
564.0376586937914
[58.09360786]
AGREEABLENESS - Mean
Mean Absolute Error: 419.3962878521661
Mean Squared Error: 227817.41050275232
Root Mean Squared Error: 477.3022213469704



0.0006685559451169443
831.4668037935967
[8.00819191]
NEUROTICISM - Mean
Mean Absolute Error: 466.41469643612055
Mean Squared Error: 293099.8063459073
Root Mean Squared Error: 541.3869284956068



Niska korelacja wpływa na złe wyniki modelu uczonego za pomocą regresji liniowej, który szuka zależności między średnim wynikiem a poszczególnymi cechami osobowości.

In [9]:
X = data[['OPENNESS']].values
Y = data['CONSCIENTIOUSNESS'].values 
showBarPlot(X,Y,'OPENNESS - CONSCIENTIOUSNESS')

X = data[['OPENNESS']].values
Y = data['EXTRAVERSION'].values 
showBarPlot(X,Y,'OPENNESS - EXTRAVERSION')

X = data[['OPENNESS']].values
Y = data['AGREEABLENESS'].values 
showBarPlot(X,Y,'OPENNESS - AGREEABLENESS')

X = data[['OPENNESS']].values
Y = data['NEUROTICISM'].values 
showBarPlot(X,Y,'OPENNESS - NEUROTICISM')
0.008578537865989833
4.811401266023664
[0.10600744]
OPENNESS - CONSCIENTIOUSNESS
Mean Absolute Error: 1.7979375686938328
Mean Squared Error: 4.7738450218791115
Root Mean Squared Error: 2.184913046754747



0.06402928086216864
3.649648660518226
[0.30511638]
OPENNESS - EXTRAVERSION
Mean Absolute Error: 1.8438991362904409
Mean Squared Error: 4.967205421520904
Root Mean Squared Error: 2.228722822946116



0.0007477731981171409
6.304230307876849
[-0.03053179]
OPENNESS - AGREEABLENESS
Mean Absolute Error: 2.2293306677329072
Mean Squared Error: 7.190473659136638
Root Mean Squared Error: 2.6815058566291885



0.010836789504800937
4.77464159928568
[0.13636589]
OPENNESS - NEUROTICISM
Mean Absolute Error: 2.7757560725565096
Mean Squared Error: 9.792270392785213
Root Mean Squared Error: 3.129260358740578



In [10]:
X = data[['CONSCIENTIOUSNESS']].values
Y = data['OPENNESS'].values 
showBarPlot(X,Y,'CONSCIENTIOUSNESS - OPENNESS')

X = data[['CONSCIENTIOUSNESS']].values
Y = data['EXTRAVERSION'].values 
showBarPlot(X,Y,'CONSCIENTIOUSNESS - EXTRAVERSION')

X = data[['CONSCIENTIOUSNESS']].values
Y = data['AGREEABLENESS'].values 
showBarPlot(X,Y,'CONSCIENTIOUSNESS - AGREEABLENESS')

X = data[['CONSCIENTIOUSNESS']].values
Y = data['NEUROTICISM'].values 
showBarPlot(X,Y,'CONSCIENTIOUSNESS - NEUROTICISM')
0.01617689853719828
4.721639656816015
[0.10631129]
CONSCIENTIOUSNESS - OPENNESS
Mean Absolute Error: 1.7885300634220023
Mean Squared Error: 4.631570464222543
Root Mean Squared Error: 2.1521083765048967



0.05197568095385329
3.805555179188402
[0.24102703]
CONSCIENTIOUSNESS - EXTRAVERSION
Mean Absolute Error: 1.5640448930740916
Mean Squared Error: 4.293197463384416
Root Mean Squared Error: 2.0720032488836537



0.02436274781932724
5.047725781762519
[0.16567352]
CONSCIENTIOUSNESS - AGREEABLENESS
Mean Absolute Error: 2.1489075369924917
Mean Squared Error: 6.389802650453431
Root Mean Squared Error: 2.5278058965144914



0.051255044861413745
7.064146682438972
[-0.28503065]
CONSCIENTIOUSNESS - NEUROTICISM
Mean Absolute Error: 1.8434724163888594
Mean Squared Error: 4.908868898265374
Root Mean Squared Error: 2.2155967363817304



In [11]:
X = data[['EXTRAVERSION']].values
Y = data['OPENNESS'].values 
showBarPlot(X,Y,'EXTRAVERSION - OPENNESS')

X = data[['EXTRAVERSION']].values
Y = data['CONSCIENTIOUSNESS'].values 
showBarPlot(X,Y,'EXTRAVERSION - CONSCIENTIOUSNESS')

X = data[['EXTRAVERSION']].values
Y = data['AGREEABLENESS'].values 
showBarPlot(X,Y,'EXTRAVERSION - AGREEABLENESS')

X = data[['EXTRAVERSION']].values
Y = data['NEUROTICISM'].values 
showBarPlot(X,Y,'EXTRAVERSION - NEUROTICISM')
0.05916895718131021
4.335373616687885
[0.20686025]
EXTRAVERSION - OPENNESS
Mean Absolute Error: 1.3938505859693795
Mean Squared Error: 3.38876151143414
Root Mean Squared Error: 1.8408589059007592



0.02723836974850502
4.582841120039463
[0.16405956]
EXTRAVERSION - CONSCIENTIOUSNESS
Mean Absolute Error: 1.5955201017745642
Mean Squared Error: 4.247209285877033
Root Mean Squared Error: 2.060875854067157



0.012497215809617823
5.433646272018804
[0.11842234]
EXTRAVERSION - AGREEABLENESS
Mean Absolute Error: 2.062206241051399
Mean Squared Error: 6.182975721663073
Root Mean Squared Error: 2.4865590123025583



0.13825469311780159
8.009096944613168
[-0.47302699]
EXTRAVERSION - NEUROTICISM
Mean Absolute Error: 2.3303433955334567
Mean Squared Error: 7.637710057433703
Root Mean Squared Error: 2.7636407251004433



In [12]:
X = data[['AGREEABLENESS']].values
Y = data['OPENNESS'].values 
showBarPlot(X,Y,'AGREEABLENESS - OPENNESS')

X = data[['AGREEABLENESS']].values
Y = data['CONSCIENTIOUSNESS'].values 
showBarPlot(X,Y,'AGREEABLENESS - CONSCIENTIOUSNESS')

X = data[['AGREEABLENESS']].values
Y = data['EXTRAVERSION'].values 
showBarPlot(X,Y,'AGREEABLENESS - EXTRAVERSION')

X = data[['AGREEABLENESS']].values
Y = data['NEUROTICISM'].values 
showBarPlot(X,Y,'AGREEABLENESS - NEUROTICISM')
0.00127872247054317
5.021748281184229
[0.02792199]
AGREEABLENESS - OPENNESS
Mean Absolute Error: 1.905584397362144
Mean Squared Error: 5.342175933364268
Root Mean Squared Error: 2.3113147629356474



0.006239414257850462
5.0027176574282635
[0.07577612]
AGREEABLENESS - CONSCIENTIOUSNESS
Mean Absolute Error: 1.7563820063330395
Mean Squared Error: 4.8865452718331674
Root Mean Squared Error: 2.210553159694009



0.025542712876959306
4.401270994098956
[0.14173854]
AGREEABLENESS - EXTRAVERSION
Mean Absolute Error: 2.240664245725526
Mean Squared Error: 7.756979534056743
Root Mean Squared Error: 2.7851354606296517



1.9702042874936154e-07
5.554056355371576
[-0.00053709]
AGREEABLENESS - NEUROTICISM
Mean Absolute Error: 2.4897795302409054
Mean Squared Error: 7.902013034150577
Root Mean Squared Error: 2.811051944406324



In [13]:
X = data[['NEUROTICISM']].values
Y = data['OPENNESS'].values 
showBarPlot(X,Y,'NEUROTICISM - OPENNESS')

X = data[['NEUROTICISM']].values
Y = data['CONSCIENTIOUSNESS'].values 
showBarPlot(X,Y,'NEUROTICISM - CONSCIENTIOUSNESS')

X = data[['NEUROTICISM']].values
Y = data['EXTRAVERSION'].values 
showBarPlot(X,Y,'NEUROTICISM - EXTRAVERSION')

X = data[['NEUROTICISM']].values
Y = data['AGREEABLENESS'].values 
showBarPlot(X,Y,'NEUROTICISM - AGREEABLENESS')
0.019830257070251656
4.682457616929182
[0.10591837]
NEUROTICISM - OPENNESS
Mean Absolute Error: 1.2262579270755296
Mean Squared Error: 2.6510513253787247
Root Mean Squared Error: 1.6282049396125553



0.08049145187582618
6.841170261154641
[-0.23842054]
NEUROTICISM - CONSCIENTIOUSNESS
Mean Absolute Error: 1.3616029156403853
Mean Squared Error: 2.704389173468309
Root Mean Squared Error: 1.6445027131228178



0.1357177461178074
7.02434531468402
[-0.31567449]
NEUROTICISM - EXTRAVERSION
Mean Absolute Error: 1.769912708624536
Mean Squared Error: 5.157070129456427
Root Mean Squared Error: 2.27091834495572



0.002453866743836852
6.195062200956937
[-0.04419139]
NEUROTICISM - AGREEABLENESS
Mean Absolute Error: 1.873105582137161
Mean Squared Error: 5.602656762070465
Root Mean Squared Error: 2.366993190119157



Niska korelacja wpływa również na złe wyniki modelu uczonego za pomocą regresji liniowej, który szuka zależności między poszczególnymi cechami osobowości.

Uczenie nadzorowane

Z analizy wynika, że nie ma szczególnych zależności między pojedynczymi cechami osobowości a wynikiem gry.

Rozpatrzymy zatem cały zbiór cech i dokonamy na nich uczenia nadzorowanego - predykcji wyniku gry w klasach "low" i "medium".

Dokonujemy porównania różnych modeli.

In [24]:
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier,NearestCentroid
from sklearn.svm import SVC
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB, MultinomialNB
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn import linear_model
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn import preprocessing
from sklearn.metrics import classification_report, f1_score
import random 

data = pd.read_csv("mean_scores.csv", sep=',')
data.head()

def mapping(x):
    if x < data['Mean'].median() :
        return 'low'
    else:
        return 'high'
            
data['Mean_Class'] = data['Mean'].map(lambda x: mapping(x));


X = data[['OPENNESS', 'CONSCIENTIOUSNESS', 'EXTRAVERSION', 'AGREEABLENESS', 'NEUROTICISM']].values

y = data['Mean_Class'].values 

names = ["Nearest Centroid ", #przypisuje do obserwacji etykietę klasy próbek treningowych, których średnia jest najbliższa obserwacji
         "Nearest Neighbors", #zależność między zmiennymi objaśniającymi a objaśnianymi jest złożona lub nietypowa
         "Linear SVC", #Support Vector Classification  
         "Gaussian Process", #bazuje na aproksymacji Laplace'a - dwie klasy
         "Decision Tree", 
         "Random Forest", 
         "Naive Bayes",#Naiwne klasyfikatory bayesowskie są oparte na założeniu o wzajemnej niezależności predyktorów
         "Logistic Regression"] #zmienna zależna przyjmuje tylko dwie wartości

classifiers = [
    NearestCentroid(),
    KNeighborsClassifier(3),
    SVC(kernel="linear", C=0.025),
    GaussianProcessClassifier(1.0 * RBF(1.0)),
    DecisionTreeClassifier(max_depth=5),
    RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
    GaussianNB(),
    linear_model.LogisticRegression(solver='lbfgs')]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42, shuffle=True)

#randomowy przydzial
y_pred = np.empty((y_test.size),dtype = '<U6' )
for i in range (0,y_test.size):
    y_pred[i] = random.choice(['low','high'])
print(classification_report(y_test, y_pred))
print("Średni wynik randomowego przydziału oscyluje ok.0.5")
print("")
      
for name, classifier in zip(names, classifiers):
    classifier.fit(X_train, y_train)
    # wyniki walidacji krzyzowej dla danego estymatora - określają jakość modelu
    # Testowanie każdego podzbioru używając pozostałych jako zbiór treningowy.
    # cv - określa krotność walidacji - przy niedużych zbiorach najczęściej k=10
    scores = cross_val_score(classifier, X, y, cv=10).tolist()
    print(name)
    print("Walidacja krzyżowa:", np.mean(scores))
    predictions = classifier.predict(X_test)
    print(classification_report(y_test, predictions))
    print("")
              precision    recall  f1-score   support

        high       0.52      0.54      0.53        26
         low       0.45      0.43      0.44        23

    accuracy                           0.49        49
   macro avg       0.49      0.49      0.49        49
weighted avg       0.49      0.49      0.49        49

Średni wynik randomowego przydziału oscyluje ok.0.5

Nearest Centroid 
Walidacja krzyżowa: 0.5946428571428573
              precision    recall  f1-score   support

        high       0.63      0.65      0.64        26
         low       0.59      0.57      0.58        23

    accuracy                           0.61        49
   macro avg       0.61      0.61      0.61        49
weighted avg       0.61      0.61      0.61        49


Nearest Neighbors
Walidacja krzyżowa: 0.5741071428571428
              precision    recall  f1-score   support

        high       0.71      0.65      0.68        26
         low       0.64      0.70      0.67        23

    accuracy                           0.67        49
   macro avg       0.67      0.67      0.67        49
weighted avg       0.68      0.67      0.67        49


Linear SVC
Walidacja krzyżowa: 0.6160714285714286
              precision    recall  f1-score   support

        high       0.65      0.50      0.57        26
         low       0.55      0.70      0.62        23

    accuracy                           0.59        49
   macro avg       0.60      0.60      0.59        49
weighted avg       0.60      0.59      0.59        49


Gaussian Process
Walidacja krzyżowa: 0.59375
              precision    recall  f1-score   support

        high       0.63      0.46      0.53        26
         low       0.53      0.70      0.60        23

    accuracy                           0.57        49
   macro avg       0.58      0.58      0.57        49
weighted avg       0.59      0.57      0.57        49


Decision Tree
Walidacja krzyżowa: 0.61875
              precision    recall  f1-score   support

        high       0.50      0.46      0.48        26
         low       0.44      0.48      0.46        23

    accuracy                           0.47        49
   macro avg       0.47      0.47      0.47        49
weighted avg       0.47      0.47      0.47        49


Random Forest
Walidacja krzyżowa: 0.55625
              precision    recall  f1-score   support

        high       0.52      0.50      0.51        26
         low       0.46      0.48      0.47        23

    accuracy                           0.49        49
   macro avg       0.49      0.49      0.49        49
weighted avg       0.49      0.49      0.49        49


Naive Bayes
Walidacja krzyżowa: 0.5205357142857142
              precision    recall  f1-score   support

        high       0.67      0.46      0.55        26
         low       0.55      0.74      0.63        23

    accuracy                           0.59        49
   macro avg       0.61      0.60      0.59        49
weighted avg       0.61      0.59      0.58        49


Logistic Regression
Walidacja krzyżowa: 0.6276785714285714
              precision    recall  f1-score   support

        high       0.64      0.54      0.58        26
         low       0.56      0.65      0.60        23

    accuracy                           0.59        49
   macro avg       0.60      0.60      0.59        49
weighted avg       0.60      0.59      0.59        49


Wyniki walidacji krzyżowej przekraczają 0.5, zatem można uznać, że modele całkiem dobrze radzą sobie z predykcją. Najlepsze wyniki osiągają modele: "Linear SVM", "Decision Tree" oraz "Logistic Regression".

F₁ jest kolejną miarą dokładności testu. W tym przypadku najlepsze wyniki osiągają modele: "Nearest Neighbors" oraz "Nearest Centroid".

Biorąc pod uwagę średnią obu wartości, najlepszym modelem okazuje się "Nearest Neighbors".

In [25]:
print("Porównanie wyników przy zmianie liczby sąsiadów:")
xx =[]
yy = []
for i in [1,2,3,5,10,15,20,30,40,50]:
    classifier = KNeighborsClassifier(i)
    classifier.fit(X_train, y_train)
    scores = cross_val_score(classifier, X, y, cv=10).tolist()
    predictions = classifier.predict(X_test)
    xx.append(np.mean(scores))
    yy.append(np.mean(f1_score(y_test, predictions, average=None)))

labels = [1,2,3,5,10,15,20,30,40,50]
def plot_bar_x():
    index = np.arange(len(labels))
    plt.bar(index, xx)
    plt.xlabel('Liczba sąsiadów')
    plt.xticks(index, labels, fontsize=10)
    plt.show()
def plot_bar_y():
    index = np.arange(len(labels))
    plt.bar(index, yy)
    plt.xlabel('Liczba sąsiadów')
    plt.xticks(index, labels, fontsize=10)
    plt.show()

print("")
print("Zmiana liczby sąsiadów a wartość walidacji krzyżowej:")
plot_bar_x()
print("Zmiana liczby sąsiadów a wartość F1:")
plot_bar_y()
Porównanie wyników przy zmianie liczby sąsiadów:

Zmiana liczby sąsiadów a wartość walidacji krzyżowej:
Zmiana liczby sąsiadów a wartość F1:

Z wykresów wynika, że najoptymalniejsza liczba sąsiadów znajduje się w przedziale <2,10>, ponieważ w obu przypadkach jakość modelu jest na wysokim poziomie.

In [ ]: